{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c9e1048",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute this cell to install dependencies\n",
    "%pip install sf-hamilton[visualization]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a44c333",
   "metadata": {},
   "source": [
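    "# Hamilton + Plotly integration [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/plotly/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/plotly/notebook.ipynb)\n",
    "\n",
    "This example builds a Hamilton dataflow that trains a classifier (an SVM on the `digits` dataset, as set in the config cell) and uses materializers to save the resulting Plotly confusion-matrix figure both as a static PNG and as an interactive HTML file.\n"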
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ec11f041",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-20T06:21:27.406089Z",
     "start_time": "2023-11-20T06:21:25.198718Z"
    }
   },
   "outputs": [],
   "source": [
    "import model_training\n",
    "\n",
    "from hamilton import driver\n",
    "from hamilton.io.materialization import to"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "544b3ab5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-20T06:21:27.443007Z",
     "start_time": "2023-11-20T06:21:27.440097Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/apache/hamilton#usage-analytics--data-privacy for details.\n"
     ]
    }
   ],
   "source": [
    "dag_config = {\n",
    "    \"test_size_fraction\": 0.95,\n",
    "    \"shuffle_train_test_split\": True,\n",
    "    \"data_loader\" : \"digits\",\n",
    "    \"clf\" : \"svm\",\n",
    "    \"penalty\" : \"l2\"\n",
    "}\n",
    "dr = (\n",
    "    driver.Builder()\n",
    "    .with_config(dag_config)\n",
    "    .with_modules(model_training)\n",
    "    .build()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f1f99ac1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-20T06:21:27.907847Z",
     "start_time": "2023-11-20T06:21:27.440712Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 8.0.5 (20230430.1635)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"1775pt\" height=\"578pt\"\n",
       " viewBox=\"0.00 0.00 1775.00 577.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 573.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-573.8 1771,-573.8 1771,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"8,-317.8 8,-561.8 116.1,-561.8 116.1,-317.8 8,-317.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"62.05\" y=\"-544.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- y_train -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>y_train</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M619.65,-145.6C619.65,-145.6 576.3,-145.6 576.3,-145.6 570.3,-145.6 564.3,-139.6 564.3,-133.6 564.3,-133.6 564.3,-94 564.3,-94 564.3,-88 570.3,-82 576.3,-82 576.3,-82 619.65,-82 619.65,-82 625.65,-82 631.65,-88 631.65,-94 631.65,-94 631.65,-133.6 631.65,-133.6 631.65,-139.6 625.65,-145.6 619.65,-145.6\"/>\n",
       "<text text-anchor=\"start\" x=\"575.85\" y=\"-122.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">y_train</text>\n",
       "<text text-anchor=\"start\" x=\"575.1\" y=\"-94.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- fit_clf -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>fit_clf</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M779,-186.6C779,-186.6 693.65,-186.6 693.65,-186.6 687.65,-186.6 681.65,-180.6 681.65,-174.6 681.65,-174.6 681.65,-135 681.65,-135 681.65,-129 687.65,-123 693.65,-123 693.65,-123 779,-123 779,-123 785,-123 791,-129 791,-135 791,-135 791,-174.6 791,-174.6 791,-180.6 785,-186.6 779,-186.6\"/>\n",
       "<text text-anchor=\"start\" x=\"718.32\" y=\"-163.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">fit_clf</text>\n",
       "<text text-anchor=\"start\" x=\"692.45\" y=\"-135.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ClassifierMixin</text>\n",
       "</g>\n",
       "<!-- y_train&#45;&gt;fit_clf -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>y_train&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M631.98,-123.72C643.79,-127.28 657.53,-131.41 671.07,-135.48\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"669.77,-139.04 680.35,-138.57 671.78,-132.34 669.77,-139.04\"/>\n",
       "</g>\n",
       "<!-- y_test -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>y_test</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M907.6,-369.6C907.6,-369.6 864.25,-369.6 864.25,-369.6 858.25,-369.6 852.25,-363.6 852.25,-357.6 852.25,-357.6 852.25,-318 852.25,-318 852.25,-312 858.25,-306 864.25,-306 864.25,-306 907.6,-306 907.6,-306 913.6,-306 919.6,-312 919.6,-318 919.6,-318 919.6,-357.6 919.6,-357.6 919.6,-363.6 913.6,-369.6 907.6,-369.6\"/>\n",
       "<text text-anchor=\"start\" x=\"866.42\" y=\"-346.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">y_test</text>\n",
       "<text text-anchor=\"start\" x=\"863.05\" y=\"-318.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- y_test_with_labels -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>y_test_with_labels</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1146.07,-391.6C1146.07,-391.6 1028.47,-391.6 1028.47,-391.6 1022.47,-391.6 1016.47,-385.6 1016.47,-379.6 1016.47,-379.6 1016.47,-340 1016.47,-340 1016.47,-334 1022.47,-328 1028.47,-328 1028.47,-328 1146.07,-328 1146.07,-328 1152.07,-328 1158.07,-334 1158.07,-340 1158.07,-340 1158.07,-379.6 1158.07,-379.6 1158.07,-385.6 1152.07,-391.6 1146.07,-391.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1027.27\" y=\"-368.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">y_test_with_labels</text>\n",
       "<text text-anchor=\"start\" x=\"1064.4\" y=\"-340.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- y_test&#45;&gt;y_test_with_labels -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>y_test&#45;&gt;y_test_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M919.84,-341.43C943.09,-344 975.3,-347.55 1005.19,-350.85\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1004.76,-354.44 1015.08,-352.05 1005.52,-347.48 1004.76,-354.44\"/>\n",
       "</g>\n",
       "<!-- confusion_matrix_png -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>confusion_matrix_png</title>\n",
       "<path fill=\"#ffc857\" stroke=\"black\" d=\"M1764.75,-478.57C1764.75,-482.96 1727.3,-486.52 1681.2,-486.52 1635.1,-486.52 1597.65,-482.96 1597.65,-478.57 1597.65,-478.57 1597.65,-407.03 1597.65,-407.03 1597.65,-402.64 1635.1,-399.08 1681.2,-399.08 1727.3,-399.08 1764.75,-402.64 1764.75,-407.03 1764.75,-407.03 1764.75,-478.57 1764.75,-478.57\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1764.75,-478.57C1764.75,-474.19 1727.3,-470.62 1681.2,-470.62 1635.1,-470.62 1597.65,-474.19 1597.65,-478.57\"/>\n",
       "<text text-anchor=\"start\" x=\"1608.45\" y=\"-451.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">confusion_matrix_png</text>\n",
       "<text text-anchor=\"start\" x=\"1629.82\" y=\"-423.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PlotlyStaticWriter</text>\n",
       "</g>\n",
       "<!-- prefit_clf -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>prefit_clf</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M640.65,-63.6C640.65,-63.6 555.3,-63.6 555.3,-63.6 549.3,-63.6 543.3,-57.6 543.3,-51.6 543.3,-51.6 543.3,-12 543.3,-12 543.3,-6 549.3,0 555.3,0 555.3,0 640.65,0 640.65,0 646.65,0 652.65,-6 652.65,-12 652.65,-12 652.65,-51.6 652.65,-51.6 652.65,-57.6 646.65,-63.6 640.65,-63.6\"/>\n",
       "<text text-anchor=\"start\" x=\"569.47\" y=\"-40.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">prefit_clf</text>\n",
       "<text text-anchor=\"start\" x=\"554.1\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ClassifierMixin</text>\n",
       "</g>\n",
       "<!-- prefit_clf&#45;&gt;fit_clf -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>prefit_clf&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M642.14,-63.99C645.75,-66.91 649.29,-69.87 652.65,-72.8 667.82,-86.02 683.7,-101.36 697.44,-115.18\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"694.32,-117.01 703.83,-121.67 699.3,-112.09 694.32,-117.01\"/>\n",
       "</g>\n",
       "<!-- confusion_matrix_html -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>confusion_matrix_html</title>\n",
       "<path fill=\"#ffc857\" stroke=\"black\" d=\"M1767,-373.57C1767,-377.96 1728.54,-381.52 1681.2,-381.52 1633.86,-381.52 1595.4,-377.96 1595.4,-373.57 1595.4,-373.57 1595.4,-302.03 1595.4,-302.03 1595.4,-297.64 1633.86,-294.08 1681.2,-294.08 1728.54,-294.08 1767,-297.64 1767,-302.03 1767,-302.03 1767,-373.57 1767,-373.57\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1767,-373.57C1767,-369.19 1728.54,-365.62 1681.2,-365.62 1633.86,-365.62 1595.4,-369.19 1595.4,-373.57\"/>\n",
       "<text text-anchor=\"start\" x=\"1606.2\" y=\"-346.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">confusion_matrix_html</text>\n",
       "<text text-anchor=\"start\" x=\"1615.57\" y=\"-318.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PlotlyInteractiveWriter</text>\n",
       "</g>\n",
       "<!-- predicted_output -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>predicted_output</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M939.85,-287.6C939.85,-287.6 832,-287.6 832,-287.6 826,-287.6 820,-281.6 820,-275.6 820,-275.6 820,-236 820,-236 820,-230 826,-224 832,-224 832,-224 939.85,-224 939.85,-224 945.85,-224 951.85,-230 951.85,-236 951.85,-236 951.85,-275.6 951.85,-275.6 951.85,-281.6 945.85,-287.6 939.85,-287.6\"/>\n",
       "<text text-anchor=\"start\" x=\"830.8\" y=\"-264.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">predicted_output</text>\n",
       "<text text-anchor=\"start\" x=\"863.05\" y=\"-236.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- fit_clf&#45;&gt;predicted_output -->\n",
       "<g id=\"edge19\" class=\"edge\">\n",
       "<title>fit_clf&#45;&gt;predicted_output</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M784.43,-187.03C798.55,-196.69 814.16,-207.38 828.8,-217.39\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"826.53,-220.77 836.76,-223.53 830.48,-214.99 826.53,-220.77\"/>\n",
       "</g>\n",
       "<!-- train_test_split_func -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>train_test_split_func</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M502.3,-307.6C502.3,-307.6 372.7,-307.6 372.7,-307.6 366.7,-307.6 360.7,-301.6 360.7,-295.6 360.7,-295.6 360.7,-256 360.7,-256 360.7,-250 366.7,-244 372.7,-244 372.7,-244 502.3,-244 502.3,-244 508.3,-244 514.3,-250 514.3,-256 514.3,-256 514.3,-295.6 514.3,-295.6 514.3,-301.6 508.3,-307.6 502.3,-307.6\"/>\n",
       "<text text-anchor=\"start\" x=\"371.5\" y=\"-284.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">train_test_split_func</text>\n",
       "<text text-anchor=\"start\" x=\"427\" y=\"-256.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;y_train -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;y_train</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M463.34,-243.63C483.61,-218.29 513.64,-182.69 543.3,-154.8 547.2,-151.13 551.43,-147.46 555.73,-143.91\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"557.41,-146.25 563.03,-137.27 553.03,-140.78 557.41,-146.25\"/>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;y_test -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;y_test</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M507.92,-308.03C519.52,-312.33 531.6,-316.15 543.3,-318.8 648.18,-342.52 775.56,-342.2 841.17,-339.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"841.06,-343.44 850.92,-339.57 840.8,-336.44 841.06,-343.44\"/>\n",
       "</g>\n",
       "<!-- X_test -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>X_test</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M619.65,-309.6C619.65,-309.6 576.3,-309.6 576.3,-309.6 570.3,-309.6 564.3,-303.6 564.3,-297.6 564.3,-297.6 564.3,-258 564.3,-258 564.3,-252 570.3,-246 576.3,-246 576.3,-246 619.65,-246 619.65,-246 625.65,-246 631.65,-252 631.65,-258 631.65,-258 631.65,-297.6 631.65,-297.6 631.65,-303.6 625.65,-309.6 619.65,-309.6\"/>\n",
       "<text text-anchor=\"start\" x=\"577.72\" y=\"-286.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">X_test</text>\n",
       "<text text-anchor=\"start\" x=\"575.1\" y=\"-258.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;X_test -->\n",
       "<g id=\"edge25\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;X_test</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M514.49,-276.76C527.87,-276.93 541.37,-277.1 553.47,-277.25\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"553.1,-280.76 563.14,-277.39 553.19,-273.76 553.1,-280.76\"/>\n",
       "</g>\n",
       "<!-- X_train -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>X_train</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M619.65,-227.6C619.65,-227.6 576.3,-227.6 576.3,-227.6 570.3,-227.6 564.3,-221.6 564.3,-215.6 564.3,-215.6 564.3,-176 564.3,-176 564.3,-170 570.3,-164 576.3,-164 576.3,-164 619.65,-164 619.65,-164 625.65,-164 631.65,-170 631.65,-176 631.65,-176 631.65,-215.6 631.65,-215.6 631.65,-221.6 625.65,-227.6 619.65,-227.6\"/>\n",
       "<text text-anchor=\"start\" x=\"575.1\" y=\"-204.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">X_train</text>\n",
       "<text text-anchor=\"start\" x=\"575.1\" y=\"-176.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- train_test_split_func&#45;&gt;X_train -->\n",
       "<g id=\"edge26\" class=\"edge\">\n",
       "<title>train_test_split_func&#45;&gt;X_train</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M502.04,-243.72C519.59,-234.87 538.2,-225.47 554.25,-217.37\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"555.53,-220.14 562.88,-212.51 552.38,-213.89 555.53,-220.14\"/>\n",
       "</g>\n",
       "<!-- data -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>data</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M79.97,-307.6C79.97,-307.6 44.12,-307.6 44.12,-307.6 38.12,-307.6 32.12,-301.6 32.12,-295.6 32.12,-295.6 32.12,-256 32.12,-256 32.12,-250 38.12,-244 44.12,-244 44.12,-244 79.97,-244 79.97,-244 85.97,-244 91.97,-250 91.97,-256 91.97,-256 91.97,-295.6 91.97,-295.6 91.97,-301.6 85.97,-307.6 79.97,-307.6\"/>\n",
       "<text text-anchor=\"start\" x=\"48.17\" y=\"-284.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data</text>\n",
       "<text text-anchor=\"start\" x=\"42.92\" y=\"-256.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Bunch</text>\n",
       "</g>\n",
       "<!-- feature_matrix -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>feature_matrix</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M280.07,-225.6C280.07,-225.6 188.72,-225.6 188.72,-225.6 182.72,-225.6 176.72,-219.6 176.72,-213.6 176.72,-213.6 176.72,-174 176.72,-174 176.72,-168 182.72,-162 188.72,-162 188.72,-162 280.07,-162 280.07,-162 286.07,-162 292.07,-168 292.07,-174 292.07,-174 292.07,-213.6 292.07,-213.6 292.07,-219.6 286.07,-225.6 280.07,-225.6\"/>\n",
       "<text text-anchor=\"start\" x=\"187.52\" y=\"-202.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">feature_matrix</text>\n",
       "<text text-anchor=\"start\" x=\"211.52\" y=\"-174.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- data&#45;&gt;feature_matrix -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>data&#45;&gt;feature_matrix</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M92.23,-258.73C105.71,-251.08 122.05,-242.15 137.1,-234.8 146.53,-230.2 156.61,-225.58 166.56,-221.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"167.66,-224.1 175.44,-216.9 164.87,-217.68 167.66,-224.1\"/>\n",
       "</g>\n",
       "<!-- target -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>target</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M256.07,-307.6C256.07,-307.6 212.72,-307.6 212.72,-307.6 206.72,-307.6 200.72,-301.6 200.72,-295.6 200.72,-295.6 200.72,-256 200.72,-256 200.72,-250 206.72,-244 212.72,-244 212.72,-244 256.07,-244 256.07,-244 262.07,-244 268.07,-250 268.07,-256 268.07,-256 268.07,-295.6 268.07,-295.6 268.07,-301.6 262.07,-307.6 256.07,-307.6\"/>\n",
       "<text text-anchor=\"start\" x=\"215.65\" y=\"-284.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target</text>\n",
       "<text text-anchor=\"start\" x=\"211.52\" y=\"-256.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- data&#45;&gt;target -->\n",
       "<g id=\"edge21\" class=\"edge\">\n",
       "<title>data&#45;&gt;target</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M92.33,-275.8C119.11,-275.8 159.07,-275.8 189.63,-275.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"189.43,-279.3 199.43,-275.8 189.43,-272.3 189.43,-279.3\"/>\n",
       "</g>\n",
       "<!-- target_names -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>target_names</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M928.97,-451.6C928.97,-451.6 842.87,-451.6 842.87,-451.6 836.87,-451.6 830.87,-445.6 830.87,-439.6 830.87,-439.6 830.87,-400 830.87,-400 830.87,-394 836.87,-388 842.87,-388 842.87,-388 928.97,-388 928.97,-388 934.97,-388 940.97,-394 940.97,-400 940.97,-400 940.97,-439.6 940.97,-439.6 940.97,-445.6 934.97,-451.6 928.97,-451.6\"/>\n",
       "<text text-anchor=\"start\" x=\"841.67\" y=\"-428.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target_names</text>\n",
       "<text text-anchor=\"start\" x=\"863.05\" y=\"-400.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- data&#45;&gt;target_names -->\n",
       "<g id=\"edge22\" class=\"edge\">\n",
       "<title>data&#45;&gt;target_names</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M92.47,-291.39C101.23,-297.34 110.09,-304.88 116.1,-313.8 138.33,-346.79 106.64,-375.22 137.1,-400.8 239.19,-486.56 303.17,-419.8 436.5,-419.8 436.5,-419.8 436.5,-419.8 598.97,-419.8 674.76,-419.8 761.85,-419.8 819.81,-419.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"819.64,-423.3 829.64,-419.8 819.64,-416.3 819.64,-423.3\"/>\n",
       "</g>\n",
       "<!-- confusion_matrix -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>confusion_matrix</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1345.55,-391.6C1345.55,-391.6 1234.7,-391.6 1234.7,-391.6 1228.7,-391.6 1222.7,-385.6 1222.7,-379.6 1222.7,-379.6 1222.7,-340 1222.7,-340 1222.7,-334 1228.7,-328 1234.7,-328 1234.7,-328 1345.55,-328 1345.55,-328 1351.55,-328 1357.55,-334 1357.55,-340 1357.55,-340 1357.55,-379.6 1357.55,-379.6 1357.55,-385.6 1351.55,-391.6 1345.55,-391.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1233.5\" y=\"-368.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">confusion_matrix</text>\n",
       "<text text-anchor=\"start\" x=\"1267.25\" y=\"-340.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- y_test_with_labels&#45;&gt;confusion_matrix -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>y_test_with_labels&#45;&gt;confusion_matrix</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1158.52,-359.8C1175.67,-359.8 1194.1,-359.8 1211.57,-359.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1211.29,-363.3 1221.29,-359.8 1211.29,-356.3 1211.29,-363.3\"/>\n",
       "</g>\n",
       "<!-- feature_matrix&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>feature_matrix&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M292.49,-217.08C310.49,-224.42 330.78,-232.69 350.34,-240.67\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"348.9,-244.27 359.48,-244.8 351.54,-237.79 348.9,-244.27\"/>\n",
       "</g>\n",
       "<!-- confusion_matrix_figure -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>confusion_matrix_figure</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1554.4,-421.6C1554.4,-421.6 1398.55,-421.6 1398.55,-421.6 1392.55,-421.6 1386.55,-415.6 1386.55,-409.6 1386.55,-409.6 1386.55,-370 1386.55,-370 1386.55,-364 1392.55,-358 1398.55,-358 1398.55,-358 1554.4,-358 1554.4,-358 1560.4,-358 1566.4,-364 1566.4,-370 1566.4,-370 1566.4,-409.6 1566.4,-409.6 1566.4,-415.6 1560.4,-421.6 1554.4,-421.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1397.35\" y=\"-398.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">confusion_matrix_figure</text>\n",
       "<text text-anchor=\"start\" x=\"1457.35\" y=\"-370.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Figure</text>\n",
       "</g>\n",
       "<!-- confusion_matrix&#45;&gt;confusion_matrix_figure -->\n",
       "<g id=\"edge23\" class=\"edge\">\n",
       "<title>confusion_matrix&#45;&gt;confusion_matrix_figure</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1357.94,-370.67C1363.67,-371.61 1369.55,-372.56 1375.46,-373.52\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1374.71,-377.11 1385.14,-375.26 1375.83,-370.2 1374.71,-377.11\"/>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>predicted_output_with_labels</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1181.7,-309.6C1181.7,-309.6 992.85,-309.6 992.85,-309.6 986.85,-309.6 980.85,-303.6 980.85,-297.6 980.85,-297.6 980.85,-258 980.85,-258 980.85,-252 986.85,-246 992.85,-246 992.85,-246 1181.7,-246 1181.7,-246 1187.7,-246 1193.7,-252 1193.7,-258 1193.7,-258 1193.7,-297.6 1193.7,-297.6 1193.7,-303.6 1187.7,-309.6 1181.7,-309.6\"/>\n",
       "<text text-anchor=\"start\" x=\"991.65\" y=\"-286.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">predicted_output_with_labels</text>\n",
       "<text text-anchor=\"start\" x=\"1064.4\" y=\"-258.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- predicted_output_with_labels&#45;&gt;confusion_matrix -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>predicted_output_with_labels&#45;&gt;confusion_matrix</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1172.11,-310.1C1179.42,-313.01 1186.7,-315.94 1193.7,-318.8 1199.78,-321.29 1206.06,-323.9 1212.36,-326.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1210.68,-330.05 1221.25,-330.72 1213.4,-323.61 1210.68,-330.05\"/>\n",
       "</g>\n",
       "<!-- predicted_output&#45;&gt;predicted_output_with_labels -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>predicted_output&#45;&gt;predicted_output_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M952.2,-263C957.88,-263.63 963.72,-264.28 969.65,-264.93\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"969.07,-268.5 979.4,-266.12 969.84,-261.54 969.07,-268.5\"/>\n",
       "</g>\n",
       "<!-- target&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>target&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M268.16,-275.8C290.26,-275.8 320.55,-275.8 349.35,-275.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"349.29,-279.3 359.29,-275.8 349.29,-272.3 349.29,-279.3\"/>\n",
       "</g>\n",
       "<!-- target_names&#45;&gt;y_test_with_labels -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>target_names&#45;&gt;y_test_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M941.4,-403.4C961.21,-397.44 984.05,-390.57 1005.65,-384.07\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1006.56,-387.15 1015.13,-380.91 1004.55,-380.44 1006.56,-387.15\"/>\n",
       "</g>\n",
       "<!-- target_names&#45;&gt;predicted_output_with_labels -->\n",
       "<g id=\"edge18\" class=\"edge\">\n",
       "<title>target_names&#45;&gt;predicted_output_with_labels</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M941.28,-388.94C945.07,-385.76 948.65,-382.38 951.85,-378.8 971.58,-356.71 959.09,-338.9 980.85,-318.8 981.72,-318 982.6,-317.22 983.5,-316.45\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"985.2,-318.78 990.99,-309.91 980.93,-313.24 985.2,-318.78\"/>\n",
       "</g>\n",
       "<!-- target_names&#45;&gt;confusion_matrix_figure -->\n",
       "<g id=\"edge24\" class=\"edge\">\n",
       "<title>target_names&#45;&gt;confusion_matrix_figure</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M941.33,-418.53C1029.01,-416.28 1207.02,-410.85 1357.55,-400.8 1363.38,-400.41 1369.38,-399.97 1375.42,-399.5\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1375.65,-402.91 1385.34,-398.62 1375.09,-395.93 1375.65,-402.91\"/>\n",
       "</g>\n",
       "<!-- confusion_matrix_figure&#45;&gt;confusion_matrix_png -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>confusion_matrix_figure&#45;&gt;confusion_matrix_png</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1566.69,-413.12C1573.38,-414.87 1580.14,-416.64 1586.86,-418.4\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1585.8,-422 1596.36,-421.14 1587.57,-415.23 1585.8,-422\"/>\n",
       "</g>\n",
       "<!-- confusion_matrix_figure&#45;&gt;confusion_matrix_html -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>confusion_matrix_figure&#45;&gt;confusion_matrix_html</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1566.69,-366.92C1572.6,-365.4 1578.56,-363.87 1584.5,-362.35\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1585.33,-365.49 1594.14,-359.62 1583.59,-358.71 1585.33,-365.49\"/>\n",
       "</g>\n",
       "<!-- X_test&#45;&gt;predicted_output -->\n",
       "<g id=\"edge20\" class=\"edge\">\n",
       "<title>X_test&#45;&gt;predicted_output</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M632.06,-275.25C675.01,-271.95 751.28,-266.08 808.86,-261.65\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"809.05,-265.07 818.75,-260.81 808.51,-258.09 809.05,-265.07\"/>\n",
       "</g>\n",
       "<!-- X_train&#45;&gt;fit_clf -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>X_train&#45;&gt;fit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M631.98,-185.88C643.79,-182.32 657.53,-178.19 671.07,-174.12\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"671.78,-177.26 680.35,-171.03 669.77,-170.56 671.78,-177.26\"/>\n",
       "</g>\n",
       "<!-- _prefit_clf_inputs -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>_prefit_clf_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"491.3,-54.1 383.7,-54.1 383.7,-9.5 491.3,-9.5 491.3,-54.1\"/>\n",
       "<text text-anchor=\"start\" x=\"398.5\" y=\"-26\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gamma</text>\n",
       "<text text-anchor=\"start\" x=\"451\" y=\"-26\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n",
       "</g>\n",
       "<!-- _prefit_clf_inputs&#45;&gt;prefit_clf -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>_prefit_clf_inputs&#45;&gt;prefit_clf</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M491.71,-31.8C504.69,-31.8 518.73,-31.8 532.19,-31.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"531.82,-35.3 541.82,-31.8 531.82,-28.3 531.82,-35.3\"/>\n",
       "</g>\n",
       "<!-- _train_test_split_func_inputs -->\n",
       "<g id=\"node20\" class=\"node\">\n",
       "<title>_train_test_split_func_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"331.7,-391.6 137.1,-391.6 137.1,-326 331.7,-326 331.7,-391.6\"/>\n",
       "<text text-anchor=\"start\" x=\"151.77\" y=\"-363.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">shuffle_train_test_split</text>\n",
       "<text text-anchor=\"start\" x=\"291.65\" y=\"-363.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n",
       "<text text-anchor=\"start\" x=\"166.4\" y=\"-342.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">test_size_fraction</text>\n",
       "<text text-anchor=\"start\" x=\"291.65\" y=\"-342.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n",
       "</g>\n",
       "<!-- _train_test_split_func_inputs&#45;&gt;train_test_split_func -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>_train_test_split_func_inputs&#45;&gt;train_test_split_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M316.04,-325.52C327.36,-320.84 339.04,-316.03 350.46,-311.31\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"351.51,-314.25 359.41,-307.2 348.84,-307.78 351.51,-314.25\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node21\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"89.05,-531.1 35.05,-531.1 35.05,-494.5 89.05,-494.5 89.05,-531.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"62.05\" y=\"-507\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node22\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M84.47,-476.1C84.47,-476.1 39.62,-476.1 39.62,-476.1 33.62,-476.1 27.62,-470.1 27.62,-464.1 27.62,-464.1 27.62,-451.5 27.62,-451.5 27.62,-445.5 33.62,-439.5 39.62,-439.5 39.62,-439.5 84.47,-439.5 84.47,-439.5 90.47,-439.5 96.47,-445.5 96.47,-451.5 96.47,-451.5 96.47,-464.1 96.47,-464.1 96.47,-470.1 90.47,-476.1 84.47,-476.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"62.05\" y=\"-452\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "<!-- output -->\n",
       "<g id=\"node23\" class=\"node\">\n",
       "<title>output</title>\n",
       "<path fill=\"#ffc857\" stroke=\"black\" d=\"M79.6,-421.1C79.6,-421.1 44.5,-421.1 44.5,-421.1 38.5,-421.1 32.5,-415.1 32.5,-409.1 32.5,-409.1 32.5,-396.5 32.5,-396.5 32.5,-390.5 38.5,-384.5 44.5,-384.5 44.5,-384.5 79.6,-384.5 79.6,-384.5 85.6,-384.5 91.6,-390.5 91.6,-396.5 91.6,-396.5 91.6,-409.1 91.6,-409.1 91.6,-415.1 85.6,-421.1 79.6,-421.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"62.05\" y=\"-397\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">output</text>\n",
       "</g>\n",
       "<!-- materializer -->\n",
       "<g id=\"node24\" class=\"node\">\n",
       "<title>materializer</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M108.1,-362.34C108.1,-364.37 87.46,-366.01 62.05,-366.01 36.64,-366.01 16,-364.37 16,-362.34 16,-362.34 16,-329.26 16,-329.26 16,-327.23 36.64,-325.59 62.05,-325.59 87.46,-325.59 108.1,-327.23 108.1,-329.26 108.1,-329.26 108.1,-362.34 108.1,-362.34\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M108.1,-362.34C108.1,-360.31 87.46,-358.66 62.05,-358.66 36.64,-358.66 16,-360.31 16,-362.34\"/>\n",
       "<text text-anchor=\"middle\" x=\"62.05\" y=\"-340\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">materializer</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x146a51850>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Save the confusion-matrix figure twice: once to the .png path and once to the .html path below.\n",
    "materializers = [\n",
    "    to.plotly(\n",
    "        dependencies=[\"confusion_matrix_figure\"],\n",
    "        id=\"confusion_matrix_png\",\n",
    "        path=\"./static.png\",\n",
    "    ),\n",
    "    to.html(\n",
    "        dependencies=[\"confusion_matrix_figure\"],\n",
    "        id=\"confusion_matrix_html\",\n",
    "        path=\"./interactive.html\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "# Draw the DAG, including the materializer nodes that will write these files.\n",
    "dr.visualize_materialization(*materializers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4298a4b9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-20T06:21:29.472853Z",
     "start_time": "2023-11-20T06:21:27.912356Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'confusion_matrix_png': {'size': 29058,\n",
       "   'path': './static.png',\n",
       "   'last_modified': 1700461289.1922433,\n",
       "   'timestamp': 1700490089.192551},\n",
       "  'confusion_matrix_html': {'size': 3607064,\n",
       "   'path': './interactive.html',\n",
       "   'last_modified': 1700461289.2231884,\n",
       "   'timestamp': 1700490089.425375}},\n",
       " {})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the DAG and write both files. The result is a 2-tuple:\n",
    "#   1. metadata keyed by materializer id (file size, path, timestamps)\n",
    "#   2. any additionally requested outputs -- none were requested here, hence {}\n",
    "dr.materialize(*materializers)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
